from pyspark.sql import SQLContext
from common.config.config import *
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import pandas as pd
import difflib

import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Module-level label vocabulary read by the to_index/to_label UDFs.
# col_encoder() and col_decoder() rebind this list before applying those UDFs,
# so this default only matters if a UDF runs before either of them.
label_l = [0, 1, 2]

def read_gp_spark(spark, table_name):
    """Load `table_name` from Greenplum over JDBC and return it as a DataFrame."""
    return spark.read.jdbc(table=table_name, url=GP_URL, properties=GP_PROP)


def save_gp_spark(dataframe, table_name):
    """Write `dataframe` to Greenplum table `table_name` via JDBC, overwriting it."""
    dataframe.write.jdbc(url=GP_URL, table=table_name, mode="overwrite", properties=GP_PROP)


# TODO optimization
def get_class_model_info(sc, model_id):
    """Look up a classification model's metadata in the MySQL `model` table.

    :param sc: SparkContext used to build the SQLContext.
    :param model_id: primary-key id of the model row.
    :return: (model_saved_path, feature_list, label_list)
    """
    sql_context = SQLContext(sc)
    df = sql_context.read.jdbc(url=MYSQL_URL, table="model", properties=MYSQL_PROP)
    condition = "id = " + str(model_id)
    # Fetch both columns in ONE Spark action instead of two separate collects
    # of the same filtered row (was the file's "TODO optimization").
    row = df.where(condition).select("model_saved_path", "other_info").collect()[0]
    model_path = row[0]
    # SECURITY(review): eval() on a DB-stored string executes arbitrary code if
    # the table is ever writable by untrusted parties. If `other_info` is a
    # plain literal dict, ast.literal_eval / json.loads would be safer.
    other_info = eval(row[1])
    return model_path, other_info["feature_list"], other_info["label_list"]


def get_model_info(sc, model_id):
    """Look up a model's saved path and feature list in the MySQL `model` table.

    :param sc: SparkContext used to build the SQLContext.
    :param model_id: primary-key id of the model row.
    :return: (model_saved_path, feature_list)
    """
    sql_context = SQLContext(sc)
    df = sql_context.read.jdbc(url=MYSQL_URL, table="model", properties=MYSQL_PROP)
    condition = "id = " + str(model_id)
    # Single collect instead of two identical filtered queries.
    rows = df.where(condition).select("model_saved_path", "other_info").collect()
    # NOTE(review): the original took the LAST row for the path but the FIRST
    # row for other_info. Preserved here, but if id is unique this is moot and
    # if it is not, the mismatch is probably a bug — confirm intent.
    model_path = rows[-1][0]
    # SECURITY(review): eval() on DB-stored text — see get_class_model_info.
    other_info = eval(rows[0][1])
    return model_path, other_info["feature_list"]


def get_distincts(sc, table_name, column):
    """Return the sorted distinct values of `column` in Greenplum `table_name`.

    :param sc: SparkContext used to build the SQLContext.
    :param table_name: JDBC table to read.
    :param column: column whose distinct values are wanted.
    :return: sorted list of distinct values.
    """
    sql_context = SQLContext(sc)
    df = sql_context.read.jdbc(url=GP_URL, table=table_name, properties=GP_PROP)
    # sorted() + comprehension replaces the manual append loop + in-place sort.
    return sorted(row[0] for row in df.select(column).distinct().collect())


# todo more advanced method
def string_similar(s1, s2):
    """Return a similarity ratio in [0, 1] for two strings via difflib."""
    matcher = difflib.SequenceMatcher(a=s1, b=s2)
    return matcher.ratio()


@udf
def to_index(x):
    """Map a raw label value to its position in the module-level label_l list."""
    if x is None:
        return None
    return label_l.index(x)


@udf
def to_label(x):
    """Map an integer index back to its value in the module-level label_l list."""
    if x is None:
        return None
    return label_l[int(x)]


@udf
def one_hot(x, y):
    """Indicator UDF: 1 when x == y, 0 otherwise; nulls (None x) propagate."""
    if x is None:
        return None
    return 1 if x == y else 0


def train_test_split(df, split_rate):
    """Split `df` into (train, test) using a fixed seed for reproducibility.

    :param df: DataFrame exposing randomSplit.
    :param split_rate: two-element list of split weights, e.g. [0.8, 0.2].
    """
    train_df, test_df = df.randomSplit(split_rate, seed=123)
    return train_df, test_df


def cast_to_float(df, column):
    """Replace `column` with a FloatType copy of itself, keeping the same name."""
    # "_Int" suffix is historical; kept byte-identical for the transient column.
    temp_name = column + "_Int"
    converted = df.withColumn(temp_name, df[column].cast(FloatType())).drop(column)
    return converted.withColumnRenamed(temp_name, column)


def cast_to_str(df, column):
    """Replace `column` with a string-typed copy of itself, keeping the same name."""
    temp_name = column + "_Str"
    converted = df.withColumn(temp_name, df[column].cast("string")).drop(column)
    return converted.withColumnRenamed(temp_name, column)


def cast_all_to_float(df, feature_col):
    """Cast every column named in `feature_col` to float, one by one."""
    for name in feature_col:
        df = cast_to_float(df, name)
    return df


def feature_cols_labelling(df, feature_col):
    """Cast the feature columns to float and assemble them into a 'features' vector."""
    casted = cast_all_to_float(df, feature_col)
    assembler = VectorAssembler(inputCols=feature_col, outputCol="features")
    return assembler.transform(casted)


def label_col_labelling(df, label_col):
    """Copy `label_col` (cast to int) into a 'label' column and drop the original."""
    return df.withColumn("label", df[label_col].cast("int")).drop(label_col)


def col_validation(df, column_list):
    """Partition `column_list` into (numeric columns, categorical/string columns)."""
    numeric, categorical = [], []
    for name in column_list:
        bucket = categorical if isinstance(df.schema[name].dataType, StringType) else numeric
        bucket.append(name)
    return numeric, categorical


def col_encoder(df, column, label_list):
    """Add '<column>_encoded' mapping each value to its index in `label_list`.

    Side effect: rebinds the module-level `label_l` that the to_index UDF reads.
    """
    global label_l
    label_l = label_list
    return df.withColumn(column + "_encoded", to_index(column))


def col_decoder(df, column, label_list):
    """Replace `column`'s encoded values with the true labels from `label_list`.

    Side effect: rebinds the module-level `label_l` that the to_label UDF reads.
    """
    global label_l
    label_l = label_list
    decoded = df.withColumn(column + "_true_label", to_label(column))
    decoded = decoded.drop(column)
    return decoded.withColumnRenamed(column + "_true_label", column)


def append_col(df, column, name):
    """Append a one-hot indicator column for category `name` of `column`.

    Returns (df with the new column and a helper 'col_var' literal column,
    new column's name). Caller is expected to drop 'col_var' afterwards.
    """
    converted = cast_to_str(df, column)
    converted = converted.withColumn("col_var", F.lit(name))
    new_col = column + "_OneHot/" + str(name)
    return converted.withColumn(new_col, one_hot(column, "col_var")), new_col


def one_hot_encoding(sc, df, source, column_list):
    """One-hot encode each column in `column_list` using its distinct values.

    Columns with more than 1024 categories are skipped (a message is recorded).
    Returns (encoded df without the helper 'col_var' column, messages,
    names of all newly added indicator columns).
    """
    notes = []
    new_cols = []
    for column in column_list:
        categories = get_distincts(sc, source, column)
        if len(categories) > 1024:
            notes.append("Column [" + str(column) + "] has too many categories so it will not be taken account. ")
            continue
        for category in categories:
            df, added = append_col(df, column, category)
            new_cols.append(added)
    return df.drop("col_var"), notes, new_cols


def one_hot_encoding_p(sc, df, source, column_list, cat_col_dict):
    """One-hot encode `column_list` at prediction time, fuzzy-matching the live
    category values against the categories recorded at training time.

    :param sc: SparkContext for reading distinct values from `source`.
    :param df: DataFrame to encode.
    :param source: table the distinct category values are read from.
    :param column_list: categorical columns to encode.
    :param cat_col_dict: mapping column -> list of categories seen in training.
    :return: (encoded df with helper 'col_var' dropped, message dict,
              list of added indicator column names)

    NOTE(review): unlike one_hot_encoding(), a column exceeding the
    1024-category limit is still encoded here — only a warning is stored.
    Confirm whether it should be skipped instead.
    """
    message = {}
    encoded_cat_cols = []
    # NOTE(review): loop index `i` is reused by the nested matching loops below;
    # this only works because `column` is bound before they run. Fragile.
    for i in range(len(column_list)):
        column = column_list[i]
        distinct_items = get_distincts(sc, source, column)
        if len(distinct_items) > 1024:
            # Later columns overwrite this single key — only the last offender
            # is reported.
            message["Warning: Too many categories"] = \
                "Column [" + str(column) + "] has too many categories so it will not be taken account. "

        distinct_items_trained = cat_col_dict[column]

        used = []
        len_di = len(distinct_items)
        len_dit = len(distinct_items_trained)
        # Start from the trained categories; matched live values replace them.
        reference_distinct_items = distinct_items_trained.copy()
        # matched[j][i] = similarity of live item i to trained item j (0 if <= 0.66).
        matched = [[0 for i in range(len_di)] for j in range(len_dit)]
        for i in range(len_di):
            for j in range(len_dit):
                sim = string_similar(distinct_items[i], distinct_items_trained[j])
                # todo > ???
                if sim > 0.66:
                    matched[j][i] = sim
        for j in range(len_dit):
            di = matched[j]
            # NOTE(review): `di` is never modified inside this loop, so every
            # pass computes the SAME argmax; only the single best live match is
            # ever assigned per trained category. If the intent was to fall back
            # to the next-best match when the best is already `used`, the max
            # entry needs zeroing each iteration — suspected bug, confirm.
            for i in range(len_di):
                index = di.index(max(di))
                if index not in used and max(di) != 0:
                    reference_distinct_items[j] = distinct_items[index]
                    used.append(index)

        # print("see how it changes")
        # print(matched)
        # print(distinct_items)
        # print(distinct_items_trained)
        # print(reference_distinct_items)
        # Report any trained category that was replaced by a fuzzy-matched value.
        if distinct_items_trained != reference_distinct_items:
            col_message = "Column: " + column + ". Training encoded feature columns: " + str(distinct_items_trained) \
                          + "; encoded feature columns for prediction:" + str(reference_distinct_items)
            onehot_warning_key = "Warning: one-hot-encoding column(s) not matched"
            if onehot_warning_key not in message.keys():
                message[onehot_warning_key] = [col_message]
            else:
                message[onehot_warning_key].append(col_message)
        for name in reference_distinct_items:
            df, encoded_col = append_col(df, column, name)
            encoded_cat_cols.append(encoded_col)
    return df.drop("col_var"), message, encoded_cat_cols


def pca_to_df(predictions, k, spark, feature_cols):
    """Flatten the 'pca_features' vector column into k scalar '_PC_i_' columns.

    Collects the frame to the driver, keeps every column except the raw
    feature columns and the 'features'/'pca_features' vectors, appends the k
    principal components as float columns, and rebuilds a Spark DataFrame via
    pandas. Column order: passthrough columns first, then _PC_1_ ... _PC_k_.

    # todo optimization keep _record_id_ and join?
    """
    pca_rows = predictions.select("pca_features").collect()
    # list.remove keeps the original strictness: a missing column raises.
    passthrough = predictions.columns
    for dropped in list(feature_cols) + ["pca_features", "features"]:
        passthrough.remove(dropped)

    table = {}
    for name in passthrough:
        table[name] = [row[0] for row in predictions.select(name).collect()]
    for i in range(k):
        # TODO SAME COLUMN NAME CHECK
        table["_PC_{}_".format(i + 1)] = [float(row[0][i]) for row in pca_rows]

    return spark.createDataFrame(pd.DataFrame.from_dict(table))


# TODO
def col_name_check(col, exist_cols):
    """Return `col`, or the first of col_2, col_3, ... absent from `exist_cols`.

    Note: the first suffix tried is _2, never _1 (matches existing behavior).
    """
    candidate = col
    suffix = 1
    while candidate in exist_cols:
        suffix += 1
        candidate = col + "_" + str(suffix)
    return candidate


# def oneHotEncoder(spark, df, df_col_list):
#
#
#     def add_id(df):
#         # Auto-incrementing ID, used later to merge the DataFrames horizontally
#         schema = df.schema.add(StructField("id", LongType()))
#         rdd = df.rdd.zipWithIndex()
#
#         def flat(l):
#             for k in l:
#                 if not isinstance(k, (list, tuple)):
#                     yield k
#                 else:
#                     yield from flat(k)
#
#         rdd = rdd.map(lambda x: list(flat(x)))
#         res = spark.createDataFrame(rdd, schema)
#         return res
#
#     df = add_id(df)
#     encoder_df = df.select(df_col_list)
#     encoder_df.show()
#     dataSet = encoder_df.rdd.map(list)
#
#     # A Row object represents one row of the DataFrame
#     def f(x, df_col_list):
#         rel = {}
#         for i in range(len(df_col_list)):
#             rel[df_col_list[i]] = x[i]
#         return rel
#
#     trainingSet = dataSet.map(lambda x: Row(**f(x, df_col_list)))
#     # print(trainingSet.first())
#
#     for i in range(len(df_col_list)):
#         cate_dic = trainingSet.map(lambda fields: fields[i]).distinct().zipWithIndex().collectAsMap()
#
#         # print("cate_dic", cate_dic)
#         col_list = []
#         for j in cate_dic:
#             col_list.append(df_col_list[i] + "_code_" + str(j))
#
#         # print(col_list)
#
#         def extract_features(fields, cate_dic, end_index):
#             # One-hot encode the category value
#             category_index = cate_dic[fields[i]]
#             # print("3",fields[3])
#             category_features = [0] * len(cate_dic)
#             category_features[category_index] = 1.0
#
#             return category_features
#
#         def extract_label(fields):
#             return float(fields[1])
#
#         labelpoint_rdd = trainingSet.map(lambda r:
#                                          LabeledPoint(extract_label(r),
#                                                       extract_features(r, cate_dic, len(r) - 1)))
#         vector_df = spark.createDataFrame(labelpoint_rdd).select("features")
#
#         # vector_df.show()
#
#         def split_vector(vector_df, col_list):
#             # Split the vector column into scalar columns; pyspark.ml.feature.VectorSlicer would also work
#             schema = vector_df.schema
#             cols = vector_df.columns
#             for col in col_list:
#                 schema = schema.add(col, DoubleType(), True)
#             split_df = spark.createDataFrame(
#                 vector_df.rdd.map(lambda row: [row[i] for i in cols] + row.features.tolist()), schema)
#             split_df = split_df.drop("features")
#             return split_df
#
#         res0 = split_vector(vector_df, col_list)
#
#         res = add_id(res0)
#         df = df.join(res, ["id"], "left")
#
#     return df


def multiclass_eval(predictions):
    """Score multiclass predictions on the standard metrics.

    :param predictions: DataFrame with 'label' and 'prediction' columns.
    :return: (accuracy, weighted precision, weighted recall, f1)
    """
    def _score(metric_name):
        # One evaluator per metric; all share the same label/prediction columns.
        # (Replaces four copy-pasted evaluator constructions.)
        return MulticlassClassificationEvaluator(
            labelCol="label", predictionCol="prediction", metricName=metric_name
        ).evaluate(predictions)

    return (_score("accuracy"), _score("weightedPrecision"),
            _score("weightedRecall"), _score("f1"))


# TODO optimization
def get_confusion_matrix(label_list, predictions, para_list):
    confusion_matrix = []
    for label in label_list:
        cm_row = []
        truths = predictions.filter(predictions[para_list["label_col"]] == label)
        for label_pred in label_list:
            cm_row.append(truths.filter(predictions["prediction"] == label_pred).count())
        confusion_matrix.append(cm_row)
    return confusion_matrix